import safety_gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import time
import numpy as np
import torch
import torch.nn as nn
import gym
import sys
import os
os.chdir('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-VAL/')
sys.path.append('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-TRAIN/')
import  core_fast as core
# from utils.logx import EpochLogger
from utils.mpi_tools import mpi_fork, proc_id, num_procs, mpi_sum
torch.autograd.set_detect_anomaly(True)
import sysv_ipc
import torch.nn.functional as F
import copy
import multiprocessing
import pandas as pd
import gc

class Safety_NN(nn.Module):
    """Fully connected classifier with two 128-unit ReLU hidden layers.

    The network maps a flat input vector of width ``n_state`` to ``n_class``
    raw scores; no activation is applied to the output layer.
    """

    def __init__(self, n_state, n_class):
        """Create the three linear layers for the given input/output widths."""
        super().__init__()
        hidden_width = 128
        # The attribute names double as state_dict keys, so they must not change.
        self.layer1 = nn.Linear(n_state, hidden_width)
        self.layer2 = nn.Linear(hidden_width, hidden_width)
        self.layer3 = nn.Linear(hidden_width, n_class)

    def forward(self, x):
        """Return the raw class scores for the input batch ``x``."""
        activations = x
        for hidden_layer in (self.layer1, self.layer2):
            activations = F.relu(hidden_layer(activations))
        return self.layer3(activations)


# Number of candidate actions scored by the safety network on every step
# (row count of the state/action batch built in pof_validation).
n_sample_of_action = 252
# Number of scores produced by each Safety_NN.
n_class = 2
# Number of ground-truth label values; indexes the prior and ans_sum tables.
n_rank = 2
# Action dimensions appended to the observation to form the Safety_NN input.
n_action = 2
# Number of Safety_NN instances that are created and evaluated.
n_NN = 1
# Observation dimensions; together with n_action this is the Safety_NN input width.
n_observations = 60
# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Precision used for the action log-probability grid in pof_validation.
dtype  = torch.float64 

def validation_for_one_process(thread_order,  agent, n_epoch, xi, th, default_action_margin, context, SNN_list,
                               storage_intest_label_list, result_intest_list, local_steps_per_epoch, render, hazard_check_margin, max_ep_len, val_data_path):
    """Roll out ``agent`` in Safexp-PointGoal1-v0 behind the safety-network filter.

    On every step the agent proposes ``n_sample_of_action`` candidate actions,
    the safety network scores them, and an action is sampled from a 376x376
    grid over [-3, 3]^2 in which every cell whose candidate is not classified
    as ``min_output_class`` is masked out.  If no candidate is classified as
    ``min_output_class`` a fixed default action is used instead.  Accumulated
    cost, reward, a cost-violation ratio and the chosen action indices are
    appended to ``POF_COST``, ``POF_PF``, ``POF_CV`` and ``POF_ACT`` files.

    Args:
        thread_order: Worker index; used in the start message and as the
            suffix of every output file name.
        agent: Policy object exposing ``stepv2(obs_tensor)`` that returns six
            values, of which the first (candidate actions), fifth (mu) and
            sixth (std) are used.
        n_epoch: Number of outer iterations.
        xi: Margin added to / subtracted from the scores when the count table
            is built.
        th: Target value for the posterior during the output-bias search.
        default_action_margin: Half-width of the ``x_vel`` band inside which
            the default action is ``[0., 0.]``.
        context: Unused; kept for interface compatibility.
        SNN_list: Safety networks, one per ``n_NN``.
        storage_intest_label_list: Unused; kept for interface compatibility.
        result_intest_list: One 2-D array per network; column 0 holds the
            label (0 or 1), columns ``1..n_class`` hold the network scores.
        local_steps_per_epoch: Environment steps per outer iteration.
        render: Render the environment when true and ``proc_id() == 0``.
        hazard_check_margin: ``state[22:38]`` entries at or above this value
            switch the default action to the hazard branch.
        max_ep_len: Step count after which the statistics are written out.
        val_data_path: String prefix prepended to every output file name.
    """
    print(str(thread_order)+"start")
    env_name = 'Safexp-PointGoal1-v0'
    env = gym.make(env_name)

    eps = 0.00000000001
    # State shared by the nested helpers below.
    normt_init = True  # the global count table is built on the first step only
    min_output_class = -1  # < 0 means the bias search has not run yet
    output_bias_for_th = 0.
    normaltable_global = np.zeros((n_NN, 2, 3, 2), dtype=int)
    posterior_global = np.zeros((n_NN, 3, 2), dtype=float)
    ans_sum = np.zeros((n_NN, 2), dtype=int)

    # 376x376 action grid; row ``i * 376 + j`` holds ``[axis[i], axis[j]]``.
    # cartesian_prod yields the same rows, in the same order, as the nested
    # Python loops it replaces.
    grid_axis = torch.linspace(-3, 3, 376)
    a_candidate = torch.cartesian_prod(grid_axis, grid_axis).to(torch.float64)
    # The grid never changes, so the device copy / view is created once here
    # instead of on every environment step.
    cand = torch.as_tensor(a_candidate, device=device, dtype=dtype).view(376, 376, 2)

    def pof_calc_posterior(normaltable_tmp, posterior_tmp, n_intest_safe, n_intest_unsafe):
        """Fill ``posterior_tmp[i][j][k]`` from the count table and ``ans_sum``."""
        n_total = n_intest_safe + n_intest_unsafe
        prior = np.array([n_intest_safe / n_total, n_intest_unsafe / n_total], dtype=float)
        for i in range(n_NN):
            for j in range(n_class):
                denominator = eps
                for k in range(n_rank):
                    denominator += prior[k] * normaltable_tmp[i][k][j][0] / (ans_sum[i][k] + eps)
                for k in range(n_rank):
                    value = ((prior[k] * normaltable_tmp[i][k][j][1] + eps) / (ans_sum[i][k] + eps) + eps) / denominator
                    # Clamp to 1; the eps terms can push the ratio above it.
                    posterior_tmp[i][j][k] = 1 if value > 1 else value

    def pof_calc_normaltable_and_posterior(normaltable_tmp, posterior_tmp, xi, result_intest_list, n_intest_safe, n_intest_unsafe):
        """Accumulate the count table for ``result_intest_list`` and derive the posterior.

        For each row, ``[..., k, 1]`` counts scores with ``score + xi`` above the
        row's best score and ``[..., k, 0]`` counts scores with ``score - xi``
        above the row's second-best score, both split by the row's label.
        """
        n_test_case_pm = n_intest_safe + n_intest_unsafe
        floor = -5000000  # starting value of both running maxima in the scalar form
        n_labels = ans_sum.shape[1]
        for i in range(n_NN):
            rows = np.asarray(result_intest_list[i])[:n_test_case_pm]
            labels = rows[:, 0].astype(np.int64)
            scores = rows[:, 1:n_class + 1].astype(np.float64)
            if labels.size and (labels.min() < 0 or labels.max() >= n_labels):
                raise ValueError("in-test labels must lie in [0, %d)" % n_labels)
            # Recount from scratch: the count table handed in is rebuilt on
            # every call, so the per-label totals must not keep accumulating.
            ans_sum[i] = np.bincount(labels, minlength=n_labels)
            # Largest and second-largest score per row, never below ``floor``.
            # NaN scores never win a comparison, so rank them as -inf.
            ranked = np.where(np.isnan(scores), -np.inf, scores)
            ranked.sort(axis=1)
            top = np.maximum(ranked[:, -1], floor)
            if n_class >= 2:
                second = np.maximum(ranked[:, -2], floor)
            else:
                second = np.full(top.shape, floor, dtype=np.float64)
            within_top = scores + xi > top[:, None]
            beats_second = scores - xi > second[:, None]
            for label in range(n_labels):
                is_label = labels == label
                normaltable_tmp[i, label, :n_class, 1] += within_top[is_label].sum(axis=0)
                normaltable_tmp[i, label, :n_class, 0] += beats_second[is_label].sum(axis=0)
        pof_calc_posterior(normaltable_tmp, posterior_tmp, n_intest_safe, n_intest_unsafe)

    def pof_th_operation(result_intest_list, xi, th):
        """Search for the output bias that moves the smallest posterior towards ``th``.

        Returns ``(min_output_class, output_bias)`` and appends a summary to the
        ``BIAS_LOG`` file of this worker.
        """
        pof_biaslog_path = val_data_path + "BIAS_LOG"+str(thread_order)+".txt"
        result_intest_list_tmp = copy.deepcopy(result_intest_list)
        n_half = int(result_intest_list[0].shape[0]/2)
        posterior_after_th_operation = np.zeros((n_NN, 3, 2), dtype=float)
        # Find the class with the smallest posterior for label 1.
        min_output_class_tmp = -1
        min_posterior = 1e9
        for j in range(n_class):
            if posterior_global[0][j][1] < min_posterior:
                min_posterior = posterior_global[0][j][1]
                min_output_class_tmp = j
        if min_output_class_tmp == -1:
            # Case where min_posterior blows up (no posterior is below 1e9).
            print("########################################### Err")
            sys.exit()
        # Coarse-to-fine search: four passes with a shrinking step size.
        output_bias = 0.
        for i in range(4):
            over_th = bool(min_posterior > th)
            fastmethod = 0.01 * 2.95**(3-i)
            return_to_previous_bias = False
            while True:
                candidate_diff = abs(min_posterior - th)
                if over_th:
                    result_intest_list_tmp[0][:, min_output_class_tmp+1] -= fastmethod
                    output_bias -= fastmethod
                else:
                    result_intest_list_tmp[0][:, min_output_class_tmp+1] += fastmethod
                    output_bias += fastmethod
                normaltable_after_th_operation = np.zeros((n_NN, 2, 3, 2), dtype=int)
                pof_calc_normaltable_and_posterior(normaltable_after_th_operation, posterior_after_th_operation, xi, result_intest_list_tmp, n_half, n_half)
                min_posterior = posterior_after_th_operation[0][min_output_class_tmp][1]
                print(i, candidate_diff, abs(min_posterior - th), min_posterior, over_th, output_bias, fastmethod, return_to_previous_bias)
                if return_to_previous_bias: break
                if over_th != (min_posterior > th):
                    # Crossed the threshold: stop if this step got closer,
                    # otherwise take one step back and then stop.
                    if abs(min_posterior - th) < candidate_diff: break
                    return_to_previous_bias = True
                    over_th = bool(min_posterior > th)
                elif abs(output_bias) >= 20: break
        print("th changes!", th, min_output_class_tmp, output_bias)
        with open(pof_biaslog_path, "a") as biaslogfile:
            biaslogfile.write("xi: "+str(xi)+", th: "+str(th) + "\n")
            biaslogfile.write("min_output_class: "+str(min_output_class_tmp)+ "\n")
            biaslogfile.write("final_bias: "+str(output_bias)+", final (min_posterior - th): "+str(min_posterior - th) + "\n")
            # NOTE(review): the same value is written twice although the label
            # says "All class"; left unchanged pending confirmation of intent.
            biaslogfile.write("All class posterior probability for label 1: "+str(posterior_after_th_operation[0][min_output_class_tmp][1])+ " "+str(posterior_after_th_operation[0][min_output_class_tmp][1])+"\n")
        return min_output_class_tmp, output_bias

    def pof_validation(state_reshape, a_candidate_dense, x_vel, hazard_check, mu, std, xi, th):
        """Pick an action for the current state.

        Returns ``(selected_action, default_action_using)``.  When the second
        value is 1, ``selected_action`` is one of the default-action codes
        (-20, -102, -61); when it is 0, ``selected_action`` is a row index
        into ``a_candidate``.
        """
        nonlocal normt_init, min_output_class, output_bias_for_th
        # Default action code, used when no candidate passes the classifier.
        if hazard_check and x_vel < 0: selected_action = -20
        elif hazard_check and x_vel > 0: selected_action = -102
        elif x_vel < -default_action_margin: selected_action = -102
        elif x_vel > default_action_margin: selected_action = -20
        else: selected_action = -61
        MU = torch.as_tensor(mu, device=device, dtype=dtype)
        STD = torch.as_tensor(std, device=device, dtype=dtype)
        for nn_order in range(n_NN):
            state_action_batch_temp = np.concatenate(
                (np.repeat(state_reshape.reshape(1, -1), repeats=n_sample_of_action, axis=0),
                 a_candidate_dense.reshape(n_sample_of_action, -1)), axis=1)
            state_action_batch_temp[:, 60:62] = np.clip(state_action_batch_temp[:, 60:62], -1, 1)
            # The first action component is reduced to its sign (+1 / -1).
            state_action_batch_temp[:, 60] = np.where(state_action_batch_temp[:, 60] > 0, 1.0, -1.0)
            # Inference only: no autograd graph is needed for the scores.
            with torch.no_grad():
                result_sample_array = SNN_list[nn_order](torch.FloatTensor(state_action_batch_temp).to(device))
        if normt_init:  # only once
            n_half = int(result_intest_list[0].shape[0]/2)
            pof_calc_normaltable_and_posterior(normaltable_global, posterior_global, xi, result_intest_list, n_half, n_half)
            normt_init = False
        if min_output_class < 0:
            min_output_class, output_bias_for_th = pof_th_operation(result_intest_list, xi, th)
        result_sample_array[:, min_output_class] += output_bias_for_th
        candlogp = -((cand[:, :, 0]-MU[0])**2)/(STD[0]**2) - ((cand[:, :, 1]-MU[1])**2)/(STD[1]**2)
        # sc[n] is 1 where candidate n scores class 1 at least as high as class 0.
        sc = (result_sample_array[:, 0] <= result_sample_array[:, 1]).to(torch.int64)
        if sc.sum() == n_sample_of_action*(1-min_output_class):
            # No candidate is classified as min_output_class: use the default.
            return selected_action, 1
        # Mask every grid cell whose candidate is not classified as
        # min_output_class.  Grid rows < 188 use candidates 0..125, rows >= 188
        # use candidates 126..251; columns outside 125..250 reuse the first /
        # last candidate of the respective half.
        NEG = -10000000000.0
        candlogp[:188, 125:251] += (sc[:126] != min_output_class).to(dtype).expand(188, 126) * NEG
        candlogp[:188, :125]    += (sc[0]   != min_output_class).to(dtype).expand(188, 125) * NEG
        candlogp[:188, 251:]    += (sc[125] != min_output_class).to(dtype).expand(188, 125) * NEG
        candlogp[188:, 125:251] += (sc[126:]!= min_output_class).to(dtype).expand(188, 126) * NEG
        candlogp[188:, :125]    += (sc[126] != min_output_class).to(dtype).expand(188, 125) * NEG
        candlogp[188:, 251:]    += (sc[251] != min_output_class).to(dtype).expand(188, 125) * NEG

        # Inverse-CDF sampling over the un-normalised weights.
        flat = candlogp.flatten()
        max_log = flat.max()
        weights = torch.exp(flat - max_log)
        cdf = torch.cumsum(weights, dim=0)
        r = torch.empty((), dtype=torch.float64, device=device).uniform_(0, cdf[-1])
        selected_action = int(torch.searchsorted(cdf, r, right=True).item())
        # Guard against an index one past the end if r lands on cdf[-1].
        selected_action = min(selected_action, flat.numel() - 1)
        return selected_action, 0

    pof_cost_path = val_data_path + "POF_COST"+str(thread_order)+".txt"
    pof_reward_path = val_data_path + "POF_PF"+str(thread_order)+".txt"
    pof_cv_path = val_data_path + "POF_CV"+str(thread_order)+".txt"
    pof_act_path = val_data_path + "POF_ACT"+str(thread_order)+".txt"

    # The context manager closes all four logs even if the rollout raises.
    with open(pof_cost_path, "a") as costfile, open(pof_reward_path, "a") as pffile, \
            open(pof_cv_path, "a") as cvfile, open(pof_act_path, "a") as acfile:
        state, reward, done, cost, ep_len = env.reset(), 0, False, 0, 0
        acc_cost = 0
        acc_reward = 0
        for epoch in range(n_epoch):
            # cv_denom[n] = number of classifier-selected (non-default) actions
            # among the first n steps since the last cost reset.
            cv_denom = [0]
            cv_denom_counter = 0
            cv_counter = 0
            cv_index = 0
            for t in range(local_steps_per_epoch):
                if render and proc_id()==0: env.render()
                a_candidate_dense, _, _, _, mu, std = agent.stepv2(torch.as_tensor(state, dtype=torch.float32))
                state_reshape = state.reshape((1,)+state.shape)
                x_vel = state[57]
                hazard_check = not all(one < hazard_check_margin for one in state[22:38])
                selected_action, default_action_using = pof_validation(
                    state_reshape, a_candidate_dense, x_vel, hazard_check, mu, std, xi, th)
                # pof_validation returns the integer flag 0/1; comparing it
                # with the string "0" was always false and kept the
                # cost-violation statistics at zero.
                if default_action_using == 0:
                    cv_denom.append(cv_denom[cv_index] + 1)
                else:
                    cv_denom.append(cv_denom[cv_index])

                # Translate default-action codes into actions and log codes.
                if selected_action == -20:
                    action = [-1., 0.]
                    selected_action = -1000
                elif selected_action == -102:
                    action = [1., 0.]
                    selected_action = -2000
                elif selected_action == -61:
                    action = [0., 0.]
                    selected_action = -3000
                else:
                    assert(selected_action >= 0)
                    action = a_candidate[selected_action]
                acfile.write(str(selected_action) + "\n")
                next_state, reward, done, info = env.step(action)
                cost = info.get('cost', 0)
                acc_cost += cost
                acc_reward += reward
                state = next_state
                ep_len += 1
                if cost != 0:
                    # Count the classifier-selected actions in the last 60
                    # steps before the violation, then restart the window.
                    cv_counter += cv_denom[cv_index + 1] - cv_denom[max(0, cv_index - 59)]
                    cv_denom_counter += cv_denom[cv_index + 1]
                    cv_denom = [0]
                    cv_index = 0
                    state, reward, done, cost = env.reset(), 0, False, 0
                else:
                    cv_index += 1

                terminal = done or (ep_len == max_ep_len)
                if terminal:
                    print("RESET at epoch:%d, local_epoch:%d" %(epoch, t+1))
                    if ep_len == max_ep_len:
                        costfile.write(str(acc_cost) + "\n")
                        pffile.write(str(acc_reward) + "\n")
                        if cost == 0: cv_denom_counter += cv_denom[cv_index]
                        cvfile.write(str(float(cv_counter/(cv_denom_counter+eps))) + "\n")
                        acc_cost = 0
                        acc_reward = 0
                        cv_denom = [0]
                        cv_denom_counter = 0
                        cv_counter = 0
                        cv_index = 0
                        # Make the results visible on disk after every episode.
                        for handle in (costfile, pffile, cvfile, acfile):
                            handle.flush()
                    state, reward, done, cost, ep_len = env.reset(), 0, False, 0, 0


if __name__ == '__main__':
    # Command-line driver: load a POF checkpoint (agent, in-test set and
    # safety classifier), score the in-test set once, and start one
    # validation worker process per (xi, th) combination.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Safexp-PointGoal1-v0')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=10000) #####
    parser.add_argument('--local_steps_per_epoch', type=int, default=1000)
    parser.add_argument('--len', type=int, default=1000)
    parser.add_argument('--exp_name', type=str, default='error')
    parser.add_argument('--checkpoint', type=str, default='-1')
    parser.add_argument('--render', action='store_true')
    args = parser.parse_args()

    from utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    # Seed setting
    seed = args.seed
    seed += 10000 * proc_id()
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda': torch.cuda.manual_seed(seed)

    # ETC setting
    env_fn = lambda: gym.make(args.env)
    ac_kwargs = dict(hidden_sizes=[args.hid]*args.l)
    render = args.render
    epochs = args.epochs
    local_steps_per_epoch = args.local_steps_per_epoch
    max_ep_len = args.len
    storage_intest_number_array = [5000000, 5000000]
    default_action_margin = 0.05
    hazard_check_margin = 0.935

    context = [60*2.5,1000-60*2.5] #####

    # Main setting: the experiment name encodes
    # <validation kind>_<checkpoint kind>_<agent>_<task>.
    exp_name_split = args.exp_name.split('_')
    if len(exp_name_split) < 4:
        # Fail with a readable message instead of an IndexError further down.
        sys.exit("--exp_name must have the form <val>_<ckpt>_<agent>_<task>, got %r" % args.exp_name)
    val_info, ckpt_info, agent_info, task_info = exp_name_split[:4]
    agent_info_sub = None  # optional agent sub-directory; currently disabled
    ckpt_info_sub = args.checkpoint

    # Path setting: refuse to overwrite the results of an earlier run.
    val_data_path = args.exp_name + "/validation_" + ckpt_info_sub + "/"
    if not os.path.exists(val_data_path):
        os.makedirs(val_data_path)
        print("    CREATE DIRECTORY %s" %(val_data_path))
    else:
        print("    DIRECTORY ALREADY EXISTS %s" %(val_data_path))
        sys.exit()
    train_default_path = "../PPO-POINT-TRAIN/" + agent_info + "_" + task_info + "_"
    if agent_info_sub is not None: train_default_path += (agent_info_sub + "/")

    # Agent setting
    if agent_info == "ppo":
        actor_critic = core.MLPActorCritic_ppo_point_train
        # One throw-away environment is enough to read both spaces.
        probe_env = env_fn()
        agent = actor_critic(probe_env.observation_space, probe_env.action_space, **ac_kwargs)
        agent.eval()
    else: sys.exit()
    # Validation setting
    if val_info == "pofval":
        SNN_list = [Safety_NN(n_observations+n_action, n_class) for _ in range(n_NN)]
        for i in range(n_NN): SNN_list[i].eval()
        storage_intest_number = storage_intest_number_array[0] + storage_intest_number_array[1]
        storage_intest_state_list = []
        storage_intest_label_list = []
    else: sys.exit()
    # Checkpoint setting
    if ckpt_info == "poftraining":
        ckpt_path = args.exp_name + "/checkpoint/pof-checkpoint-1_epoch-" + ckpt_info_sub + ".pt"
        try: # (1)agent, (2)intest, (3)classifier
            backup = torch.load(ckpt_path)
            # Explicit check instead of assert so it also runs under ``-O``.
            if storage_intest_number != backup["storage_intest_number"]: #####
                raise ValueError("checkpoint holds %s in-test samples, expected %s"
                                 % (backup["storage_intest_number"], storage_intest_number))
            storage_intest_state_list = np.array([backup["storage_intest_state"]])
            storage_intest_label_list = np.array([backup["storage_intest_label"]])
            print(storage_intest_state_list.shape)
            agent.load_state_dict(backup["ac_ppo"])
            agent.eval()
            for i in range(n_NN):
                temp_name = "SNN"
                SNN_list[i].load_state_dict(backup[temp_name])
                SNN_list[i].eval()
            print("    LOAD POF %s" %(ckpt_path))
        except Exception as e:
            print(e)
            sys.exit()
    else:
        # Without a checkpoint there is no in-test set to validate against.
        sys.exit("unsupported checkpoint kind %r in --exp_name" % ckpt_info)
    # Run
    if val_info == "pofval":
        torch.multiprocessing.set_start_method('spawn')
        start_time = time.time()
        result_intest_list = list()
        for nn_order in range(n_NN):
            intest_samples = storage_intest_state_list[nn_order]
            intest_samples[:,60:62] = np.clip(intest_samples[:,60:62],-1,1)
            # The first action component is reduced to its sign (+1 / -1).
            intest_samples[:,60] = np.where(intest_samples[:,60] > 0, 1.0, -1.0)
            # Inference only: skipping autograd avoids retaining the hidden
            # activations of the whole in-test set.
            with torch.no_grad():
                intest_scores = SNN_list[nn_order].to(device)(torch.FloatTensor(intest_samples).to(device))
            result_intest_tmp = intest_scores.cpu().numpy()/2.
            # Column 0: label, columns 1..n_class: halved network scores.
            result_intest_list.append(np.concatenate((storage_intest_label_list[nn_order].reshape(-1,1), result_intest_tmp), axis=1))
            gc.collect()
            torch.cuda.empty_cache()
        xi_array = [10 / (storage_intest_number_array[0] / 5000)] #####
        n_xi = len(xi_array)
        th_array = [0.00001]
        n_th = len(th_array)
        n_total_process = n_xi * n_th
        num_cores = 10
        procs = []
        for thread_order in range(n_total_process):
            p = multiprocessing.Process(target=validation_for_one_process, args=(thread_order, agent, epochs, xi_array[thread_order//n_th], th_array[thread_order%n_th],
                                                                                default_action_margin, context, SNN_list, storage_intest_label_list, result_intest_list,
                                                                                local_steps_per_epoch, render, hazard_check_margin, max_ep_len, val_data_path))
            p.start()
            procs.append(p)
            print("thread_order: ", thread_order)
            # Never keep more than num_cores workers alive at once.
            while (len(procs) >= num_cores): procs.pop(0).join()
        for p in procs:
            p.join()
        print(time.time()-start_time)
    else: sys.exit()